{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88e11407-68bb-4075-922d-a50125eeaac2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install celltypist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd684c3-173f-4662-8575-e4c5411b4fa7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from os.path import join\n",
    "\n",
    "import anndata\n",
    "import scanpy as sc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import dask.dataframe as dd\n",
    "import dask.array as da\n",
    "\n",
    "from scipy.sparse import csr_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01af4fbc-bd8d-4a08-bb25-4fb66487e2a9",
   "metadata": {},
   "source": [
    "# Get subset training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3efabd1-ec3a-4550-be1f-cf8a8b4d119d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_count_matrix_and_obs(ddf):\n",
    "    x = (\n",
    "        ddf['X']\n",
    "        .map_partitions(\n",
    "            lambda xx: pd.DataFrame(np.vstack(xx.tolist())), \n",
    "            meta={col: 'f4' for col in range(19331)}\n",
    "        )\n",
    "        .to_dask_array(lengths=[1024] * ddf.npartitions)\n",
    "    )\n",
    "    obs = ddf[['cell_type']].compute()\n",
    "    \n",
    "    return x, obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fcc7d3b-f557-4e2a-914b-aba95e9b7feb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b716d9-65e9-48c3-a799-338afe2473ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ddf = dd.read_parquet(join(PATH, 'train'), split_row_groups=True)\n",
    "x, obs = get_count_matrix_and_obs(ddf)\n",
    "var = pd.read_parquet(join(PATH, 'var.parquet'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31651587-9547-40fa-a180-7a40085dd677",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "start = 0\n",
    "subsample_size = 1_500_000\n",
    "# data is already shuffled -> just take first x cells\n",
    "# data is already normalized\n",
    "adata_train = anndata.AnnData(\n",
    "    X=x[start:start+subsample_size].map_blocks(csr_matrix).compute(), \n",
    "    obs=obs.iloc[start:start+subsample_size],\n",
    "    var=var.set_index('feature_name')\n",
    ")\n",
    "\n",
    "adata_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c83cf482-4f87-478a-b3de-69e834223df5",
   "metadata": {},
   "source": [
    "# Fit celltyist model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e82fe7a-2ff5-4a83-aa97-a8a04647376f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import celltypist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80fa0837-7363-4798-8539-275ba20040b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "new_model = celltypist.train(\n",
    "    adata_train, \n",
    "    labels='cell_type', \n",
    "    n_jobs=20, \n",
    "    feature_selection=True,\n",
    "    use_SGD=True, \n",
    "    mini_batch=True,\n",
    "    batch_number=1500,\n",
    "    with_mean=False,\n",
    "    random_state=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12231d4b-0d54-467e-9444-b888d809fbf1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "new_model.write(f'/mnt/dssfs02/tb_logs/cxg_2023_05_15_celltypist/model_{subsample_size}_cells_run1.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae7117f-80f2-4c8e-8049-bdde8a6aa397",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe3044e-5946-4279-a50d-f144ce0d6205",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
